#!/usr/bin/env python3
"""
Phase 1: Data Assembly for CCA Tipping Point Paper
Merges full_panel + trilemma + deepening panels, constructs transition indicators,
and produces Table 1 (summary statistics by country group).
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS

# ── Paths ──────────────────────────────────────────────────────────────────
# Inputs are read from the sibling multilateral/followup project; everything
# this script writes lands under cca_tipping/ (processed panel + tables).
FOLLOWUP_DIR = PROJECT_DIR.joinpath("multilateral", "followup")
CCA_DIR = PROJECT_DIR.joinpath("cca_tipping")
PROCESSED_DIR = CCA_DIR.joinpath("data", "processed")
TABLE_DIR = CCA_DIR.joinpath("output", "tables")

# Create the output directories up front; existing directories are accepted.
for _out_dir in (PROCESSED_DIR, TABLE_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

# ── Country Classifications ────────────────────────────────────────────────
# CCA country sample, split by commodity-exporter status.
CCA_COMMODITY = {'AZE', 'KAZ', 'RUS', 'TKM', 'UZB'}
CCA_NON_COMMODITY = {'ARM', 'BLR', 'GEO', 'KGZ', 'MDA', 'MNG', 'TJK', 'UKR'}
# Derived from the two sub-groups (instead of being listed a third time) so the
# full CCA set can never drift out of sync with its commodity split.
CCA_ALL = CCA_COMMODITY | CCA_NON_COMMODITY
BALTIC = {'EST', 'LVA', 'LTU'}
# Central and Eastern Europe.
CEE = {'ALB', 'BGR', 'BIH', 'HRV', 'CZE', 'HUN', 'MKD', 'MNE', 'POL',
       'ROU', 'SRB', 'SVK', 'SVN'}
ALL_TRANSITION = CCA_ALL | BALTIC | CEE

# Each transition economy must belong to exactly one sub-group: the group
# classifier returns the first matching group, so an overlap would silently
# mislabel a country. Fail fast instead.
if (len(CCA_COMMODITY) + len(CCA_NON_COMMODITY) + len(BALTIC) + len(CEE)
        != len(ALL_TRANSITION)):
    raise ValueError("Transition country groups overlap: each ISO3 code must "
                     "belong to exactly one of CCA_COMMODITY, "
                     "CCA_NON_COMMODITY, BALTIC, CEE")

# OECD member states (38 ISO3 codes). Eight of them (CZE, EST, HUN, LVA, LTU,
# POL, SVK, SVN) are also transition economies and are grouped by their
# transition group rather than as 'OECD'.
OECD_MEMBERS = {
    'AUS', 'AUT', 'BEL', 'CAN', 'CHL', 'COL', 'CRI', 'CZE', 'DNK', 'EST',
    'FIN', 'FRA', 'DEU', 'GRC', 'HUN', 'ISL', 'IRL', 'ISR', 'ITA', 'JPN',
    'KOR', 'LVA', 'LTU', 'LUX', 'MEX', 'NLD', 'NZL', 'NOR', 'POL', 'PRT',
    'SVK', 'SVN', 'ESP', 'SWE', 'CHE', 'TUR', 'GBR', 'USA'
}

# ── Load & Merge Panels ───────────────────────────────────────────────────
print("=" * 70)
print("PHASE 1: DATA ASSEMBLY")
print("=" * 70)


def _new_columns_only(df, base, keys=('iso3', 'year')):
    """Restrict *df* to the merge keys plus columns not already in *base*.

    Merging a frame that shares non-key columns with *base* makes pandas emit
    ``_x``/``_y`` suffixed duplicates, after which lookups by the plain column
    name (e.g. the WGI columns used for the governance composite) fail. Here
    the *base* version of any shared column wins. Rows are de-duplicated on
    *keys* (first occurrence kept) so a left merge cannot multiply base rows.

    Args:
        df: Frame about to be merged into *base*.
        base: Frame holding the panel built so far.
        keys: Merge-key columns; always retained.

    Returns:
        A new DataFrame with the retained columns and unique key rows.
    """
    keys = list(keys)
    keep = [c for c in df.columns if c in keys or c not in base.columns]
    shadowed = [c for c in df.columns if c not in keep]
    if shadowed:
        print(f"  Skipping columns already in base panel: {shadowed}")
    return df[keep].drop_duplicates(subset=keys)


# 1. Base panel (followup 140-country version)
fp = pd.read_csv(FOLLOWUP_DIR / "data" / "processed" / "full_panel.csv")
print(f"Followup panel: {fp['iso3'].nunique()} countries, {len(fp):,} obs")

# 2. Trilemma panel — regime indices (ers/fo/mi) and peg/float/eurozone flags
tri_cols = ['iso3', 'year', 'ers_index', 'fo_index', 'mi_index',
            'is_peg', 'is_float', 'eurozone']
tp = pd.read_csv(PROJECT_DIR / "trilemma" / "data" / "processed" / "trilemma_panel.csv",
                 usecols=tri_cols)
tp = _new_columns_only(tp, fp)
panel = fp.merge(tp, on=['iso3', 'year'], how='left')

# 3. Deepening panel — WGI indicators plus capital_per_worker and labsh.
# Same collision rule as the trilemma merge, so a WGI column that already
# exists in the base panel cannot turn into rule_of_law_x / rule_of_law_y.
wgi_cols = ['iso3', 'year', 'rule_of_law', 'regulatory_quality',
            'control_corruption', 'govt_effectiveness',
            'capital_per_worker', 'labsh']
dp = pd.read_csv(PROJECT_DIR / "capital_deepening" / "data" / "processed" / "deepening_panel.csv",
                 usecols=wgi_cols)
dp = _new_columns_only(dp, panel)
panel = panel.merge(dp, on=['iso3', 'year'], how='left')
print(f"Merged panel: {panel['iso3'].nunique()} countries, {len(panel):,} obs, {len(panel.columns)} columns")

# ── Construct Variables ────────────────────────────────────────────────────
print("\nConstructing variables...")

# Governance composite: row-wise mean of whichever WGI indicators are present.
# DataFrame.mean(axis=1) skips NaN, so a single non-missing indicator is enough
# to produce a value; rows with all four indicators missing stay NaN.
wgi_vars = ['rule_of_law', 'regulatory_quality', 'control_corruption', 'govt_effectiveness']
panel['governance_composite'] = panel[wgi_vars].mean(axis=1)
n_gov = panel['governance_composite'].notna().sum()
print(f"  governance_composite: {n_gov:,} non-null obs")

# Financial depth: no credit/GDP series is available in the merged panel, so
# financial_depth is a straight copy of KAOPEN.
# NOTE: it carries no information beyond `kaopen`; using both adds nothing.
panel['financial_depth'] = panel['kaopen']

# Years since independence / transition start (1991 post-Soviet, 1990 CEE)
def years_since_independence(row):
    """Return the years elapsed since the transition start for *row*'s country.

    CCA and Baltic countries count from 1991 and CEE countries from 1990.
    Years before the start are floored at 0. Countries outside the transition
    sample yield NaN.

    NOTE(review): several CEE states gained independence after 1990; 1990 is
    presumably a common transition-start year rather than literal
    independence — confirm this is intended.
    """
    iso, year = row['iso3'], row['year']
    post_soviet = iso in CCA_ALL or iso in BALTIC
    if not post_soviet and iso not in CEE:
        return np.nan
    start_year = 1991 if post_soviet else 1990
    return max(0, year - start_year)

# Row-wise apply because the helper needs both iso3 and year from each row.
panel['years_since_indep'] = panel.apply(years_since_independence, axis=1)
n_ysi = panel['years_since_indep'].count()  # count() tallies non-missing values
print(f"  years_since_indep: {n_ysi:,} non-null obs")

# Transition index: equal-weighted mean of the z-scored governance composite
# and KAOPEN. financial_depth is not included: it is a copy of kaopen and
# would only double-weight it. Z-scores use the mean/SD pooled over all
# transition-country rows; non-transition rows stay NaN.
trans_mask = panel['iso3'].isin(ALL_TRANSITION)
for v in ['governance_composite', 'kaopen']:
    col = f'{v}_z'
    panel[col] = np.nan
    sub = panel.loc[trans_mask, v]
    sd = sub.std()
    # A usable SD needs at least two distinct non-missing values; otherwise
    # keep the z-score NaN rather than dividing by zero/NaN (which can give ±inf).
    if pd.notna(sd) and sd > 0:
        panel.loc[trans_mask, col] = (sub - sub.mean()) / sd

# Row mean skips NaN, so a country-year with only one available z-score
# still receives an index value.
panel['transition_index'] = panel[['governance_composite_z', 'kaopen_z']].mean(axis=1)

# Country group classification
def classify_country(iso):
    """Map an ISO3 code to its Table 1 country-group label.

    Transition groups are checked first, in a fixed order, so transition
    economies that are also OECD members (e.g. CZE, EST, POL) keep their
    transition label. Codes that match nothing fall into 'Rest'.
    """
    ordered_groups = (
        ('CCA_commodity', CCA_COMMODITY),
        ('CCA_non_commodity', CCA_NON_COMMODITY),
        ('Baltic', BALTIC),
        ('CEE', CEE),
    )
    for label, members in ordered_groups:
        if iso in members:
            return label
    oecd_only = iso in OECD_MEMBERS and iso not in ALL_TRANSITION
    return 'OECD' if oecd_only else 'Rest'

panel['country_group'] = panel['iso3'].apply(classify_country)

# 0.0/1.0 float membership dummies; dict insertion order fixes the column
# order in the saved panel.
_dummy_members = {
    'is_cca': CCA_ALL,
    'is_cca_non_commodity': CCA_NON_COMMODITY,
    'is_transition': ALL_TRANSITION,
    'is_baltic': BALTIC,
    'is_cee': CEE,
}
for dummy_col, member_set in _dummy_members.items():
    panel[dummy_col] = panel['iso3'].isin(member_set).astype(float)

# Interaction terms: each Z_k times the governance composite and times the
# transition dummy.
for z_name in ('Z_1', 'Z_2', 'Z_3'):
    z_values = panel[z_name]
    panel[f'{z_name}_x_governance'] = z_values * panel['governance_composite']
    panel[f'{z_name}_x_transition'] = z_values * panel['is_transition']

print("\nGroup distribution:")
for group_name, group_rows in panel.groupby('country_group'):
    n_iso = group_rows['iso3'].nunique()
    print(f"  {group_name:<20s}: {n_iso:3d} countries, {len(group_rows):6,} obs")

# ── Filter to estimation sample ───────────────────────────────────────────
# Country-years with non-missing ca_gdp inside 1986–2024 (both ends included).
has_ca = panel['ca_gdp'].notna()
in_window = panel['year'].between(1986, 2024)
est = panel.loc[has_ca & in_window].copy()
n_est_countries = est['iso3'].nunique()
print(f"\nEstimation sample (CA notna, 1986-2024): {n_est_countries} countries, {len(est):,} obs")

# ── Save Panel ─────────────────────────────────────────────────────────────
# Persist the full merged panel (every country-year), not just `est`.
cca_panel_path = PROCESSED_DIR / "cca_panel.csv"
panel.to_csv(cca_panel_path, index=False)
n_panel_rows, n_panel_cols = panel.shape
print(f"\nSaved: {cca_panel_path}")
print(f"  {panel['iso3'].nunique()} countries, {n_panel_rows:,} obs, {n_panel_cols} columns")

# ── Table 1: Summary Statistics by Group ───────────────────────────────────
print("\n" + "=" * 70)
print("TABLE 1: SUMMARY STATISTICS BY COUNTRY GROUP")
print("=" * 70)

# Candidate variables; those absent from the estimation sample are dropped.
stat_vars = [v for v in ['ca_gdp', 'Z_1', 'Z_2', 'Z_3', 'old_dep', 'youth_dep',
             'kaopen', 'governance_composite', 'nfa_gdp', 'rgdp_growth',
             'fiscal_bal_gdp', 'gdp_pc_ppp', 'life_expectancy']
             if v in est.columns]

groups = ['CCA_commodity', 'CCA_non_commodity', 'Baltic', 'CEE', 'OECD', 'Rest']


def _summarize_group(df, label, variables):
    """Build one Table 1 row: label, country/obs counts, and mean and SD per variable.

    Each statistic uses only that variable's non-missing values. SD is the
    pandas default sample SD (ddof=1). A variable with no observations in *df*
    gets NaN for both mean and SD.
    """
    row = {'Group': label, 'N_countries': df['iso3'].nunique(), 'N_obs': len(df)}
    for var in variables:
        vals = df[var].dropna()
        has_data = len(vals) > 0
        row[f'{var}_mean'] = vals.mean() if has_data else np.nan
        row[f'{var}_sd'] = vals.std() if has_data else np.nan
    return row


# One row per group, plus a pooled "All" row built by the same helper so the
# per-group and pooled statistics are computed identically.
rows = [_summarize_group(est[est['country_group'] == grp], grp, stat_vars)
        for grp in groups]
rows.append(_summarize_group(est, 'All', stat_vars))

stats_df = pd.DataFrame(rows)
stats_df.to_csv(TABLE_DIR / "table1_summary_statistics.csv", index=False)

# Markdown table
all_var_labels = {
    'ca_gdp': 'CA/GDP (%)',
    'Z_1': 'Z₁ (demographics)',
    'Z_2': 'Z₂',
    'Z_3': 'Z₃',
    'old_dep': 'Old-age dep. ratio',
    'youth_dep': 'Youth dep. ratio',
    'kaopen': 'KAOPEN',
    'governance_composite': 'Governance (WGI)',
    'nfa_gdp': 'NFA/GDP',
    'rgdp_growth': 'Real GDP growth (%)',
    'fiscal_bal_gdp': 'Fiscal balance/GDP (%)',
    'gdp_pc_ppp': 'GDP per capita (PPP)',
    'life_expectancy': 'Life expectancy',
}
var_labels = {k: v for k, v in all_var_labels.items() if k in stat_vars}

# Columns: one per group, then the pooled "All" sample; every header shows
# its number of countries.
table_cols = groups + ['All']
by_group = stats_df.set_index('Group')


def _fmt_stat(value, spec):
    """Format *value* with *spec*; a missing value renders as an em dash."""
    return "—" if pd.isna(value) else format(value, spec)


md_lines = ["# Table 1: Summary Statistics by Country Group\n"]
header = "| Variable | " + " | ".join(
    f"{g} (n={by_group.at[g, 'N_countries']})" for g in table_cols) + " |"
sep = "|" + "---|" * (len(table_cols) + 1)
md_lines.append(header)
md_lines.append(sep)

for v, label in var_labels.items():
    # GDP per capita is large-magnitude: thousands separators, no decimals.
    spec = ",.0f" if v == 'gdp_pc_ppp' else ".2f"
    cells = []
    for grp in table_cols:
        m = by_group.at[grp, f'{v}_mean']
        s = by_group.at[grp, f'{v}_sd']
        if pd.isna(m):
            cells.append("—")
        else:
            # SD is NaN for single-observation cells; show "—" instead of "nan".
            cells.append(f"{_fmt_stat(m, spec)} ({_fmt_stat(s, spec)})")
    md_lines.append(f"| {label} | " + " | ".join(cells) + " |")

md_text = "\n".join(md_lines)
# Explicit UTF-8: the table contains non-ASCII characters (Z₁, —) that the
# platform default encoding may not be able to represent.
(TABLE_DIR / "table1_summary_statistics.md").write_text(md_text + "\n", encoding="utf-8")

print(md_text)
print(f"\nSaved: {TABLE_DIR / 'table1_summary_statistics.csv'}")
print(f"Saved: {TABLE_DIR / 'table1_summary_statistics.md'}")
print("\nPhase 1 complete.")
